from copy import deepcopy
from functools import partial

import jax
import jax.numpy as jnp
import optax
import math

from flax.training.train_state import TrainState
from ml_collections import ConfigDict
from core.core_api import Algo
from utilities.jax_utils import mse_loss, next_rng, value_and_multi_grad

from diffusion.diffusion import GaussianDiffusion
import distrax
from algos.distributional_rl import QRAgent, C51Agent

def update_target_network(main_params, target_params, tau):
    """Polyak-average ``main_params`` into ``target_params``.

    Every leaf is combined as ``tau * main + (1 - tau) * target``. For finite
    inputs, ``tau == 1`` reproduces the main parameters and ``tau == 0``
    reproduces the target parameters.

    Args:
        main_params: Pytree of online-network parameters.
        target_params: Pytree of target-network parameters; must share the
            tree structure of ``main_params``.
        tau: Interpolation weight given to ``main_params``.

    Returns:
        A new pytree with the structure of ``main_params`` holding the blended
        leaves. The input trees are not modified.
    """
    # ``jax.tree_map`` was deprecated in JAX 0.4.26 and later removed;
    # ``jax.tree_util.tree_map`` works on both older and newer JAX releases.
    return jax.tree_util.tree_map(
        lambda main, target: tau * main + (1.0 - tau) * target,
        main_params,
        target_params,
    )

 
class DiffusionQL(Algo):

    @staticmethod
    def get_default_config(updates=None):
        """Return the default DiffusionQL hyper-parameters as a ``ConfigDict``.

        Args:
            updates: Optional mapping of overrides. When it is not ``None`` it
                is wrapped in a ``ConfigDict``, its field references are
                resolved with ``copy_and_resolve_references()``, and the
                result is merged over the defaults with ``ConfigDict.update``.

        Returns:
            A ``ConfigDict`` holding every default listed below, with
            ``updates`` (if given) applied on top.
        """
        defaults = {
            # Core RL / diffusion settings.
            'nstep': 1,
            'discount': 0.99,
            'tau': 0.005,
            'policy_tgt_freq': 5,
            'num_timesteps': 100,
            'schedule_name': 'linear',
            'time_embed_size': 16,
            'alpha': 2.,  # NOTE 0.25 in diffusion rl but 2.5 in td3
            'use_pred_astart': True,
            'max_q_backup': False,
            'max_q_backup_topk': 1,
            'max_q_backup_samples': 10,
            'guide_warmup': False,
            'diff_annealing': False,

            # Learning-related settings.
            'lr': 3e-4,
            'diff_coef': 1.0,
            'guide_coef': 1.0,
            'lr_decay': False,
            'train_steps': 1000000,
            'lr_decay_steps': 1000000,
            'max_grad_norm': 0.,
            'weight_decay': 0.,

            'loss_type': 'TD3',
            'target_clip': False,
            'trust_region_target': False,
            'MAX_Q': 0.0,
            'use_expectile': False,  # False: CRR; True: IQL
            'expectile_q': False,  # use td of expectile v to estimate q

            'adv_norm': False,
            # CRR-related hyper-parameters.
            'sample_actions': 20,
            'crr_weight_mode': 'mle',
            'fixed_std': True,
            'crr_multi_sample_mse': False,
            'crr_avg_fn': 'mean',
            'crr_fn': 'exp',

            # IQL-related hyper-parameters.
            'expectile': 0.7,

            # Hyper-parameters shared by CRR and IQL.
            'crr_ratio_upper_bound': 20,
            'crr_beta': 1.0,
            'awr_temperature': 3.0,

            # Distributional RL. Following D4PG and QR-DDPG, use 51 bins for
            # c51 and 201 bins for QR.
            'use_dist_rl': False,
            'dist_type': 'qr',  # c51 / qr
            'num_atoms': 201,

            # Network resets.
            'reset_q': False,
            'reset_mode': 'all',  # all / last / SP
            'reset_actor': False,
            'reset_interval': 1000000,
            'max_tgt_q': False,  # update actor by maximizing target q

            # DPM-solver sampling.
            'dpm_steps': 15,
            'dpm_t_end': 0.001,

            # Flagged as unused ("useless") by the original author.
            # NOTE(review): presumably kept so the key still exists for
            # overrides -- confirm before removing.
            'target_entropy': -1,
        }
        cfg = ConfigDict(defaults)
        if updates is None:
            return cfg
        overrides = ConfigDict(updates).copy_and_resolve_references()
        cfg.update(overrides)
        return cfg
